import numpy as np
import scipy
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import re
import time
import os
from collections import deque
from utils import df1, df2
import argparse
from pathlib import Path

def LMC(config):
    """Run a Langevin Monte Carlo chain on ``n`` independent particles.

    Each step applies the Euler-Maruyama update

        x <- x - h * grad_U(x) - sqrt(2h) * xi,    xi ~ N(0, I)

    to an ``(n, d)`` particle array and, for each of the last ``k`` steps,
    stores ``stats function(x)`` as one row of the returned history.

    Args:
        config: dict with keys
            'step size'         -- h, must be > 0.
            'grad potential'    -- callable applied to the (n, d) particle
                                   array; its result is scaled by h and
                                   subtracted from it.
            'num samples'       -- n, number of particles.
            'dimension'         -- d.
            'initial condition' -- value added to an (n, d) zero array
                                   (scalar or anything broadcastable).
            'T'                 -- time horizon; number of steps is T / h,
                                   snapped to the nearest integer when the
                                   ratio is within float noise of one,
                                   truncated otherwise.
            'keep last k'       -- k, number of trailing steps to record.
            'stats function'    -- callable applied to the (n, d) particle
                                   array; its result is stored as a row of
                                   a length-d array.

    Returns:
        np.ndarray of shape (k, d). The last row is the statistic of the
        final state; earlier rows are the preceding steps in order.

    Raises:
        ValueError: if the step size is not positive, or if k exceeds the
            number of steps (the excess rows could never be filled).
    """
    h = config['step size']
    df = config['grad potential']
    n = config['num samples']
    d = config['dimension']
    ic = config['initial condition']
    k = config['keep last k']
    fn = config['stats function']

    if h <= 0:
        raise ValueError(f"step size must be positive, got {h}")

    # T / h is frequently a hair below an integer in floating point
    # (e.g. 0.3 / 0.1 == 2.9999999999999996); plain int() truncation would
    # silently drop the final step. Snap only when within float noise so a
    # genuinely non-integer ratio still truncates as before.
    ratio = config['T'] / h
    nearest = round(ratio)
    if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(ratio)):
        n_iter = int(nearest)
    else:
        n_iter = int(ratio)

    if k > n_iter:
        # Otherwise the leading k - n_iter rows would stay zero and be
        # indistinguishable from real statistics.
        raise ValueError(
            f"keep last k ({k}) exceeds the number of steps ({n_iter})"
        )

    hist = np.zeros((k, d))
    first_kept = n_iter - k          # 0-based index of the first recorded step
    noise_scale = np.sqrt(2 * h)     # loop-invariant

    x = np.zeros((n, d)) + ic
    for i in range(n_iter):
        # The noise is subtracted rather than added. xi is a symmetric
        # Gaussian, so the chain's law is unchanged; keeping the sign (and
        # the df-then-randn call order) preserves seeded outputs.
        x -= h * df(x) + noise_scale * np.random.randn(n, d)
        if i >= first_kept:
            hist[i - first_kept] = fn(x)

    return hist

if __name__ == '__main__':
    # Command-line entry point: run LMC for one potential/seed pair and save
    # the recorded statistics under ../scratch/benchmark/<potential>/.
    parser = argparse.ArgumentParser()
    # Only these two potentials are wired up below. Any other value (or
    # omitting the flag) used to leave 'grad potential' as None and fail
    # later inside LMC with an opaque TypeError.
    parser.add_argument("--potential", type=str, required=True,
                        choices=['log-sum-exp', 'cosine'])
    # Seeds NumPy's global RNG and names the output shard. If omitted, NumPy
    # seeds from fresh entropy and the file is named part=None.
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()

    np.random.seed(args.seed)

    config = {
        'step size': 1e-4,
        'num samples': int(1e4),
        'grad potential': None,
        'dimension': 10,
        'initial condition': 1,
        'T': 10,
        'stats function': lambda x: x.mean(axis=0),
        'keep last k': 100
    }

    d = config['dimension']
    if args.potential == 'log-sum-exp':
        config['grad potential'] = df1
    elif args.potential == 'cosine':
        # df2 takes the dimension as an extra argument.
        config['grad potential'] = lambda x: df2(x, d)

    # perf_counter is the monotonic clock intended for elapsed-time
    # measurement; the timing was previously computed and then discarded.
    start = time.perf_counter()
    hist = LMC(config)
    elapsed = time.perf_counter() - start
    print(f"LMC potential={args.potential} d={d} seed={args.seed}: "
          f"{elapsed:.2f}s")

    out_dir = Path("../scratch/benchmark") / args.potential
    out_dir.mkdir(parents=True, exist_ok=True)
    np.save(out_dir / f'd={d}_part={args.seed}.npy', hist)